/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tomcat.websocket.server;

import org.apache.catalina.connector.RequestFacade;
import org.apache.tomcat.util.codec.binary.Base64;
import org.apache.tomcat.util.res.StringManager;
import org.apache.tomcat.util.security.ConcurrentMessageDigest;
import org.apache.tomcat.websocket.Constants;
import org.apache.tomcat.websocket.*;
import org.apache.tomcat.websocket.pojo.PojoEndpointServer;

import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletRequestWrapper;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.Endpoint;
import javax.websocket.Extension;
import javax.websocket.HandshakeResponse;
import javax.websocket.server.ServerEndpointConfig;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.Map.Entry;

/**
 * Static helpers used on the server side to recognise an HTTP request that
 * asks to be upgraded to WebSocket and to complete the opening handshake for
 * it.
 */
public class UpgradeUtil {

	private static final StringManager sm = StringManager
			.getManager(org.apache.tomcat.websocket.server.Constants.PACKAGE_NAME);
	// Fixed GUID that is hashed together with the client's key to produce the
	// Sec-WebSocket-Accept value (see getWebSocketAccept, RFC 6455 4.2.2).
	private static final byte[] WS_ACCEPT =
			"258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(
					StandardCharsets.ISO_8859_1);

	private UpgradeUtil() {
		// Utility class. Hide default constructor.
	}

	/**
	 * Checks to see if this is an HTTP request that includes a valid upgrade
	 * request to web socket.
	 * <p>
	 * Note: RFC 2616 does not limit HTTP upgrade to GET requests but the Java
	 * WebSocket spec 1.0, section 8.2 implies such a limitation and RFC
	 * 6455 section 4.1 requires that a WebSocket Upgrade uses GET.
	 *
	 * @param request  the incoming request
	 * @param response the response associated with {@code request}
	 *
	 * @return {@code true} if both objects are HTTP request/response objects,
	 *         the Upgrade header contains the WebSocket token and the method
	 *         is GET
	 */
	public static boolean isWebSocketUpgradeRequest(ServletRequest request,
	                                                ServletResponse response) {

		return ((request instanceof HttpServletRequest) &&
				(response instanceof HttpServletResponse) &&
				headerContainsToken((HttpServletRequest) request,
						Constants.UPGRADE_HEADER_NAME,
						Constants.UPGRADE_HEADER_VALUE) &&
				"GET".equals(((HttpServletRequest) request).getMethod()));
	}

	/**
	 * Validates the remaining handshake headers, negotiates the sub-protocol
	 * and extensions, writes the handshake response headers and upgrades the
	 * connection to a {@link WsHttpUpgradeHandler}.
	 * <p>
	 * If the Connection header, the WebSocket version or the WebSocket key is
	 * invalid, or the configurator rejects the Origin, an error status is set
	 * on {@code resp} and the method returns without upgrading.
	 *
	 * @param sc         the server container, passed on to the upgrade handler
	 * @param req        the upgrade request
	 * @param resp       the response used for the handshake reply
	 * @param sec        configuration of the endpoint the request is for
	 * @param pathParams path parameter values for this request; passed to the
	 *                   handshake request and the upgrade handler and, for POJO
	 *                   endpoints, stored in the session's user properties
	 *
	 * @throws ServletException if the negotiated extensions use incompatible
	 *         RSV bits, the endpoint instance cannot be created, or the
	 *         unwrapped request is not a {@link RequestFacade}
	 * @throws IOException if an I/O error occurs while sending an error
	 *         response or performing the upgrade
	 */
	public static void doUpgrade(WsServerContainer sc, HttpServletRequest req,
	                             HttpServletResponse resp, ServerEndpointConfig sec,
	                             Map<String, String> pathParams)
			throws ServletException, IOException {

		// Validate the rest of the headers and reject the request if that
		// validation fails
		if (!headerContainsToken(req, Constants.CONNECTION_HEADER_NAME,
				Constants.CONNECTION_HEADER_VALUE)) {
			resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
			return;
		}
		if (!headerContainsToken(req, Constants.WS_VERSION_HEADER_NAME,
				Constants.WS_VERSION_HEADER_VALUE)) {
			// 426 Upgrade Required, advertising the supported version
			resp.setStatus(426);
			resp.setHeader(Constants.WS_VERSION_HEADER_NAME,
					Constants.WS_VERSION_HEADER_VALUE);
			return;
		}
		String key = req.getHeader(Constants.WS_KEY_HEADER_NAME);
		// RFC 6455 4.2.1: the key must be the base64 encoding of a 16-byte
		// value; any other value makes the handshake invalid.
		if (!validateKey(key)) {
			resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
			return;
		}


		// Origin check
		String origin = req.getHeader(Constants.ORIGIN_HEADER_NAME);
		if (!sec.getConfigurator().checkOrigin(origin)) {
			resp.sendError(HttpServletResponse.SC_FORBIDDEN);
			return;
		}
		// Sub-protocols
		List<String> subProtocols = getTokensFromHeader(req,
				Constants.WS_PROTOCOL_HEADER_NAME);
		String subProtocol = sec.getConfigurator().getNegotiatedSubprotocol(
				sec.getSubprotocols(), subProtocols);

		// Extensions
		// Should normally only be one header but handle the case of multiple
		// headers
		List<Extension> extensionsRequested = new ArrayList<Extension>();
		Enumeration<String> extHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME);
		// getHeaders() may return null if the container denies header access
		if (extHeaders != null) {
			while (extHeaders.hasMoreElements()) {
				Util.parseExtensionHeader(extensionsRequested, extHeaders.nextElement());
			}
		}
		// Negotiation phase 1. By default this simply filters out the
		// extensions that the server does not support but applications could
		// use a custom configurator to do more than this.
		List<Extension> installedExtensions;
		if (sec.getExtensions().isEmpty()) {
			installedExtensions = Constants.INSTALLED_EXTENSIONS;
		} else {
			installedExtensions = new ArrayList<Extension>();
			installedExtensions.addAll(sec.getExtensions());
			installedExtensions.addAll(Constants.INSTALLED_EXTENSIONS);
		}
		List<Extension> negotiatedExtensionsPhase1 = sec.getConfigurator().getNegotiatedExtensions(
				installedExtensions, extensionsRequested);

		// Negotiation phase 2. Create the Transformations that will be applied
		// to this connection. Note that an extension may be dropped at this
		// point if the client has requested a configuration that the server is
		// unable to support.
		List<Transformation> transformations = createTransformations(negotiatedExtensionsPhase1);

		List<Extension> negotiatedExtensionsPhase2;
		if (transformations.isEmpty()) {
			negotiatedExtensionsPhase2 = Collections.emptyList();
		} else {
			negotiatedExtensionsPhase2 = new ArrayList<Extension>(transformations.size());
			for (Transformation t : transformations) {
				negotiatedExtensionsPhase2.add(t.getExtensionResponse());
			}
		}

		// Build the transformation pipeline. Every subsequent transformation
		// is passed to setNext() on the first one, which becomes the head of
		// the pipeline.
		Transformation transformation = null;
		StringBuilder responseHeaderExtensions = new StringBuilder();
		boolean first = true;
		for (Transformation t : transformations) {
			if (first) {
				first = false;
			} else {
				responseHeaderExtensions.append(',');
			}
			append(responseHeaderExtensions, t.getExtensionResponse());
			if (transformation == null) {
				transformation = t;
			} else {
				transformation.setNext(t);
			}
		}

		// Now we have the full pipeline, validate the use of the RSV bits.
		if (transformation != null && !transformation.validateRsvBits(0)) {
			throw new ServletException(sm.getString("upgradeUtil.incompatibleRsv"));
		}

		// If we got this far, all is good. Accept the connection.
		resp.setHeader(Constants.UPGRADE_HEADER_NAME,
				Constants.UPGRADE_HEADER_VALUE);
		resp.setHeader(Constants.CONNECTION_HEADER_NAME,
				Constants.CONNECTION_HEADER_VALUE);
		resp.setHeader(HandshakeResponse.SEC_WEBSOCKET_ACCEPT,
				getWebSocketAccept(key));
		if (subProtocol != null && subProtocol.length() > 0) {
			// RFC6455 4.2.2 explicitly states "" is not valid here
			resp.setHeader(Constants.WS_PROTOCOL_HEADER_NAME, subProtocol);
		}
		if (!transformations.isEmpty()) {
			resp.setHeader(Constants.WS_EXTENSIONS_HEADER_NAME, responseHeaderExtensions.toString());
		}

		WsHandshakeRequest wsRequest = new WsHandshakeRequest(req, pathParams);
		WsHandshakeResponse wsResponse = new WsHandshakeResponse();
		WsPerSessionServerEndpointConfig perSessionServerEndpointConfig =
				new WsPerSessionServerEndpointConfig(sec);
		sec.getConfigurator().modifyHandshake(perSessionServerEndpointConfig,
				wsRequest, wsResponse);
		wsRequest.finished();

		// Add any additional headers set by the configurator
		for (Entry<String, List<String>> entry :
				wsResponse.getHeaders().entrySet()) {
			for (String headerValue : entry.getValue()) {
				resp.addHeader(entry.getKey(), headerValue);
			}
		}

		Endpoint ep;
		try {
			Class<?> clazz = sec.getEndpointClass();
			if (Endpoint.class.isAssignableFrom(clazz)) {
				ep = (Endpoint) sec.getConfigurator().getEndpointInstance(
						clazz);
			} else {
				ep = new PojoEndpointServer();
				// Need to make path params available to POJO
				perSessionServerEndpointConfig.getUserProperties().put(
						PojoEndpointServer.POJO_PATH_PARAM_KEY, pathParams);
			}
		} catch (InstantiationException e) {
			throw new ServletException(e);
		}

		// Small hack until the Servlet API provides a way to do this.
		ServletRequest inner = req;
		// Unwrap the request
		while (inner instanceof ServletRequestWrapper) {
			inner = ((ServletRequestWrapper) inner).getRequest();
		}
		if (inner instanceof RequestFacade) {
			WsHttpUpgradeHandler wsHandler =
					((RequestFacade) inner).upgrade(WsHttpUpgradeHandler.class);
			wsHandler.preInit(ep, perSessionServerEndpointConfig, sc, wsRequest,
					negotiatedExtensionsPhase2, subProtocol, transformation, pathParams,
					req.isSecure());
		} else {
			throw new ServletException("Upgrade failed");
		}
	}

	/*
	 * Checks that the Sec-WebSocket-Key value has the form RFC 6455 section
	 * 4.2.1 requires: the base64 encoding (RFC 4648 section 4 alphabet) of a
	 * 16-byte value. Such an encoding is always 24 characters: 22 alphabet
	 * characters followed by "==" padding. The unused low-order bits of the
	 * last data character are not checked, so non-canonical encodings that
	 * still decode to 16 bytes are accepted.
	 */
	private static boolean validateKey(String key) {
		if (key == null || key.length() != 24) {
			return false;
		}
		if (key.charAt(22) != '=' || key.charAt(23) != '=') {
			return false;
		}
		for (int i = 0; i < 22; i++) {
			if (!isBase64AlphabetChar(key.charAt(i))) {
				return false;
			}
		}
		return true;
	}

	/*
	 * Returns true if c is in the standard (not URL-safe) base64 alphabet.
	 */
	private static boolean isBase64AlphabetChar(char c) {
		return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
				(c >= '0' && c <= '9') || c == '+' || c == '/';
	}

	/*
	 * Groups the parameter lists of the negotiated extensions by extension
	 * name (in first-seen order) and asks the TransformationFactory for one
	 * Transformation per name. Names for which the factory returns null are
	 * dropped. The final create() argument is always true; presumably it
	 * selects server-side behaviour (confirm against TransformationFactory).
	 */
	private static List<Transformation> createTransformations(
			List<Extension> negotiatedExtensions) {

		TransformationFactory factory = TransformationFactory.getInstance();

		LinkedHashMap<String, List<List<Extension.Parameter>>> extensionPreferences =
				new LinkedHashMap<String, List<List<Extension.Parameter>>>();

		// Result will likely be smaller than this
		List<Transformation> result = new ArrayList<Transformation>(negotiatedExtensions.size());

		for (Extension extension : negotiatedExtensions) {
			List<List<Extension.Parameter>> preferences =
					extensionPreferences.get(extension.getName());

			if (preferences == null) {
				preferences = new ArrayList<List<Extension.Parameter>>();
				extensionPreferences.put(extension.getName(), preferences);
			}

			preferences.add(extension.getParameters());
		}

		for (Map.Entry<String, List<List<Extension.Parameter>>> entry :
				extensionPreferences.entrySet()) {
			Transformation transformation = factory.create(entry.getKey(), entry.getValue(), true);
			if (transformation != null) {
				result.add(transformation);
			}
		}
		return result;
	}

	/*
	 * Appends the extension as "name;param[=value];..." to sb. Does nothing
	 * for a null extension or one with a null or empty name.
	 */
	private static void append(StringBuilder sb, Extension extension) {
		if (extension == null || extension.getName() == null || extension.getName().length() == 0) {
			return;
		}

		sb.append(extension.getName());

		for (Extension.Parameter p : extension.getParameters()) {
			sb.append(';');
			sb.append(p.getName());
			if (p.getValue() != null) {
				sb.append('=');
				sb.append(p.getValue());
			}
		}
	}

	/*
	 * Returns true if any comma-separated, trimmed token of any value of the
	 * named header equals target, ignoring case. A null header enumeration
	 * (container denies header access) is treated as no values.
	 *
	 * This only works for tokens. Quoted strings need more sophisticated
	 * parsing.
	 */
	private static boolean headerContainsToken(HttpServletRequest req,
	                                           String headerName, String target) {
		Enumeration<String> headers = req.getHeaders(headerName);
		if (headers == null) {
			return false;
		}
		while (headers.hasMoreElements()) {
			String header = headers.nextElement();
			String[] tokens = header.split(",");
			for (String token : tokens) {
				if (target.equalsIgnoreCase(token.trim())) {
					return true;
				}
			}
		}
		return false;
	}

	/*
	 * Returns the comma-separated, trimmed tokens of every value of the named
	 * header, in order. A null header enumeration (container denies header
	 * access) yields an empty list.
	 *
	 * This only works for tokens. Quoted strings need more sophisticated
	 * parsing.
	 */
	private static List<String> getTokensFromHeader(HttpServletRequest req,
	                                                String headerName) {
		List<String> result = new ArrayList<String>();
		Enumeration<String> headers = req.getHeaders(headerName);
		if (headers == null) {
			return result;
		}
		while (headers.hasMoreElements()) {
			String header = headers.nextElement();
			String[] tokens = header.split(",");
			for (String token : tokens) {
				result.add(token.trim());
			}
		}
		return result;
	}

	/*
	 * Computes the Sec-WebSocket-Accept value: the SHA-1 digest of the
	 * client's key (ISO-8859-1 bytes) followed by the WS_ACCEPT GUID,
	 * base64 encoded.
	 */
	private static String getWebSocketAccept(String key) {
		byte[] digest = ConcurrentMessageDigest.digestSHA1(
				key.getBytes(StandardCharsets.ISO_8859_1), WS_ACCEPT);
		return Base64.encodeBase64String(digest);
	}
}
